{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.5"
    },
    "colab": {
      "name": "charsiu_demo.ipynb",
      "provenance": []
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "0zxKOeyTROc2"
      },
      "source": [
        "# Use the %pip magic (not !pip) so packages are installed into the running kernel's environment.\n",
        "%pip install torch torchvision torchaudio\n",
        "%pip install datasets transformers\n",
        "%pip install g2p_en praatio librosa"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "DIROcsj7Rv4g",
        "outputId": "290c10ae-5ee4-4b9f-86de-8de075c80a76"
      },
      "source": [
        "import os\n",
        "from os.path import exists, join, expanduser\n",
        "\n",
        "# Work from the user's home directory; every path below is relative to it.\n",
        "os.chdir(expanduser(\"~\"))\n",
        "charsiu_dir = 'charsiu'\n",
        "# Always start from a fresh checkout: remove any previous clone first.\n",
        "# The directory is taken from `charsiu_dir` (relative to home) instead of a hard-coded /root path,\n",
        "# so the cleanup also works when the home directory is not /root.\n",
        "if exists(charsiu_dir):\n",
        "  !rm -rf {charsiu_dir}\n",
        "if not exists(charsiu_dir):\n",
        "  ! git clone -b development https://github.com/lingjzhu/$charsiu_dir\n",
        "  # Each `!` line runs in its own subshell, so this cd does not persist; it only reports the branch status.\n",
        "  ! cd {charsiu_dir} && git checkout && cd -\n",
        "  \n",
        "# Make the repository the working directory so relative paths (e.g. 'src') resolve inside it.\n",
        "os.chdir(charsiu_dir)    "
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Cloning into 'charsiu'...\n",
            "remote: Enumerating objects: 322, done.\u001b[K\n",
            "remote: Counting objects: 100% (322/322), done.\u001b[K\n",
            "remote: Compressing objects: 100% (260/260), done.\u001b[K\n",
            "remote: Total 322 (delta 158), reused 159 (delta 55), pack-reused 0\u001b[K\n",
            "Receiving objects: 100% (322/322), 511.27 KiB | 12.47 MiB/s, done.\n",
            "Resolving deltas: 100% (158/158), done.\n",
            "Your branch is up to date with 'origin/development'.\n",
            "/root\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GmHNb4OxRVD8"
      },
      "source": [
        "import sys\n",
        "import torch\n",
        "from itertools import groupby\n",
        "from datasets import load_dataset\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "sys.path.insert(0,'src')\n",
        "from Charsiu import charsiu_chain_attention_aligner, charsiu_chain_forced_aligner, charsiu_predictive_aligner"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "q7paWfYdROc5"
      },
      "source": [
        "# Download TIMIT ('timit_asr') via datasets.load_dataset; the result is indexed by split name, e.g. timit['train'].\n",
        "timit = load_dataset('timit_asr')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "psXcfdsd48NJ",
        "outputId": "8b13c545-1e59-4e4e-cb4e-d3b021a88e3f",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "# Take the first training utterance and pull out its transcript and audio file path.\n",
        "sample = timit['train'][0]\n",
        "text, audio_path = sample['text'], sample['file']\n",
        "print(f'Text transcription:{text}')\n",
        "print(f'Audio path: {audio_path}')"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Text transcription:Would such an act of refusal be useful?\n",
            "Audio path: /root/.cache/huggingface/datasets/downloads/extracted/404950a46da14eac65eb4e2a8317b1372fb3971d980d91d5d5b221275b1fd7e0/data/TRAIN/DR4/MMDM0/SI681.WAV\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gUye9Hgpzpxb"
      },
      "source": [
        "Phone recognizer + Neural Forced Alignment"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yHW92QgDROc4"
      },
      "source": [
        "# Load the chained attention aligner. `aligner` and `recognizer` name the two model checkpoints it uses\n",
        "# (presumably Hugging Face Hub ids for the alignment and phone-recognition models -- confirm).\n",
        "charsiu = charsiu_chain_attention_aligner(aligner='charsiu/en_w2v2_fs_10ms',recognizer='charsiu/en_w2v2_ctc_libris_and_cv')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gJkF2z91ROc5"
      },
      "source": [
        "# Only the audio path is passed (no transcript). The printed result below is a list of\n",
        "# (start, end, phone) tuples with times in seconds.\n",
        "alignment = charsiu.align(audio=audio_path)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gLE1r9LQ5CYq",
        "outputId": "cdd35c8f-620b-4d04-d51e-bf37d9ebcf21",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "print(alignment)\n",
        "print('\\n Ground Truth \\n')\n",
        "# TIMIT annotates phone boundaries in samples; divide by the 16 kHz sampling rate to get seconds.\n",
        "TIMIT_SAMPLE_RATE = 16000\n",
        "phones = sample['phonetic_detail']\n",
        "print([(start/TIMIT_SAMPLE_RATE, stop/TIMIT_SAMPLE_RATE, phone)\n",
        "       for start, stop, phone in zip(phones['start'], phones['stop'], phones['utterance'])])"
      ],
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[(0.0, 0.11, '[SIL]'), (0.11, 0.16, 'W'), (0.16, 0.21, 'IH'), (0.21, 0.27, 'DH'), (0.27, 0.38, 'S'), (0.38, 0.49, 'AH'), (0.49, 0.58, 'CH'), (0.58, 0.63, 'AE'), (0.63, 0.69, 'N'), (0.69, 0.83, 'AE'), (0.83, 0.94, 'K'), (0.94, 1.0, 'T'), (1.0, 1.04, 'IH'), (1.04, 1.12, 'V'), (1.12, 1.18, 'R'), (1.18, 1.24, 'AH'), (1.24, 1.34, 'F'), (1.34, 1.43, 'Y'), (1.43, 1.5, 'UW'), (1.5, 1.58, 'Z'), (1.58, 1.64, 'AH'), (1.64, 1.73, 'L'), (1.73, 1.79, 'B'), (1.79, 1.9, 'IY'), (1.9, 2.01, 'Y'), (2.01, 2.09, 'UW'), (2.09, 2.17, 'S'), (2.17, 2.24, 'F'), (2.24, 2.31, 'AH'), (2.31, 2.4, 'L'), (2.4, 2.48, '[SIL]')]\n",
            "\n",
            " Ground Truth \n",
            "\n",
            "[(0.0, 0.1225, 'h#'), (0.1225, 0.154125, 'w'), (0.154125, 0.2175, 'ix'), (0.2175, 0.25, 'dcl'), (0.25, 0.3725, 's'), (0.3725, 0.4675, 'ah'), (0.4675, 0.4925, 'tcl'), (0.4925, 0.5875, 'ch'), (0.5875, 0.6225, 'ix'), (0.6225, 0.6675, 'n'), (0.6675, 0.8425, 'ae'), (0.8425, 0.98, 'kcl'), (0.98, 0.9925, 't'), (0.9925, 1.0575, 'ix'), (1.0575, 1.1435625, 'v'), (1.1435625, 1.180125, 'r'), (1.180125, 1.2175, 'ix'), (1.2175, 1.3576875, 'f'), (1.3576875, 1.40725, 'y'), (1.40725, 1.5025, 'ux'), (1.5025, 1.574375, 'zh'), (1.574375, 1.6925, 'el'), (1.6925, 1.76, 'bcl'), (1.76, 1.785, 'b'), (1.785, 1.8825, 'iy'), (1.8825, 1.9895, 'y'), (1.9895, 2.0775, 'ux'), (2.0775, 2.165, 's'), (2.165, 2.248, 'f'), (2.248, 2.3575, 'el'), (2.3575, 2.495, 'h#')]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5R2M4YMHUf-X"
      },
      "source": [
        "# Save the alignment for this audio file to a TextGrid file (presumably Praat format -- confirm).\n",
        "# NOTE: later cells pass the same 'sample.TextGrid' path, so this file gets overwritten.\n",
        "charsiu.serve(audio=audio_path, save_to='sample.TextGrid')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yCmbdfpzXrQ3"
      },
      "source": [
        "# Load the chained forced aligner with the 'fc' aligner checkpoint and the same recognizer as before.\n",
        "# Rebinding `charsiu` replaces the previously loaded aligner object.\n",
        "charsiu = charsiu_chain_forced_aligner(aligner='charsiu/en_w2v2_fc_10ms',recognizer='charsiu/en_w2v2_ctc_libris_and_cv')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SdZqWsE45Swv"
      },
      "source": [
        "# Re-align the same audio with the forced aligner; `alignment` is overwritten with the new result.\n",
        "alignment = charsiu.align(audio=audio_path)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "68sFGMxN5Y7b",
        "outputId": "493ed56f-64fb-43a4-ed64-e5257896a2b0",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "print(alignment)\n",
        "print('\\n Ground Truth \\n')\n",
        "# TIMIT annotates phone boundaries in samples; divide by the 16 kHz sampling rate to get seconds.\n",
        "TIMIT_SAMPLE_RATE = 16000\n",
        "phones = sample['phonetic_detail']\n",
        "print([(start/TIMIT_SAMPLE_RATE, stop/TIMIT_SAMPLE_RATE, phone)\n",
        "       for start, stop, phone in zip(phones['start'], phones['stop'], phones['utterance'])])"
      ],
      "execution_count": 24,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[(0.0, 0.08, '[SIL]'), (0.08, 0.16, 'W'), (0.16, 0.21, 'IH'), (0.21, 0.22, 'DH'), (0.22, 0.38, 'S'), (0.38, 0.46, 'AH'), (0.46, 0.58, 'CH'), (0.58, 0.6, 'AE'), (0.6, 0.68, 'N'), (0.68, 0.82, 'AE'), (0.82, 0.93, 'K'), (0.93, 0.99, 'T'), (0.99, 1.05, 'IH'), (1.05, 1.13, 'V'), (1.13, 1.17, 'R'), (1.17, 1.22, 'AH'), (1.22, 1.33, 'F'), (1.33, 1.41, 'Y'), (1.41, 1.48, 'UW'), (1.48, 1.53, 'Z'), (1.53, 1.62, 'AH'), (1.62, 1.68, 'L'), (1.68, 1.78, 'B'), (1.78, 1.88, 'IY'), (1.88, 1.99, 'Y'), (1.99, 2.08, 'UW'), (2.08, 2.13, 'S'), (2.13, 2.23, 'F'), (2.23, 2.27, 'AH'), (2.27, 2.47, 'L'), (2.47, 2.48, '[SIL]')]\n",
            "\n",
            " Ground Truth \n",
            "\n",
            "[(0.0, 0.1225, 'h#'), (0.1225, 0.154125, 'w'), (0.154125, 0.2175, 'ix'), (0.2175, 0.25, 'dcl'), (0.25, 0.3725, 's'), (0.3725, 0.4675, 'ah'), (0.4675, 0.4925, 'tcl'), (0.4925, 0.5875, 'ch'), (0.5875, 0.6225, 'ix'), (0.6225, 0.6675, 'n'), (0.6675, 0.8425, 'ae'), (0.8425, 0.98, 'kcl'), (0.98, 0.9925, 't'), (0.9925, 1.0575, 'ix'), (1.0575, 1.1435625, 'v'), (1.1435625, 1.180125, 'r'), (1.180125, 1.2175, 'ix'), (1.2175, 1.3576875, 'f'), (1.3576875, 1.40725, 'y'), (1.40725, 1.5025, 'ux'), (1.5025, 1.574375, 'zh'), (1.574375, 1.6925, 'el'), (1.6925, 1.76, 'bcl'), (1.76, 1.785, 'b'), (1.785, 1.8825, 'iy'), (1.8825, 1.9895, 'y'), (1.9895, 2.0775, 'ux'), (2.0775, 2.165, 's'), (2.165, 2.248, 'f'), (2.248, 2.3575, 'el'), (2.3575, 2.495, 'h#')]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PqI70asA5btU"
      },
      "source": [
        "# Save the forced aligner's output to a TextGrid file.\n",
        "# NOTE: the same 'sample.TextGrid' path is used by the other serve() cells, so the file is overwritten.\n",
        "charsiu.serve(audio=audio_path, save_to='sample.TextGrid')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "O2hj-1ZJ1tfA"
      },
      "source": [
        "Direct inference with frame classification model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yG_eg8KJ1snD"
      },
      "source": [
        "# Load the predictive aligner for direct frame-classification inference; only an aligner checkpoint\n",
        "# is given here (no separate recognizer argument).\n",
        "charsiu = charsiu_predictive_aligner(aligner='charsiu/en_w2v2_fc_10ms')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qmkfJ9gD1tE0"
      },
      "source": [
        "# Align the same audio with the predictive aligner; `alignment` is overwritten with the new result.\n",
        "alignment = charsiu.align(audio=audio_path)"
      ],
      "execution_count": 27,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IRQr20Av2Iq7",
        "outputId": "7b1802ac-0ea9-4582-ec92-b1e9282122c0",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "print(alignment)\n",
        "print('\\n Ground Truth \\n')\n",
        "# TIMIT annotates phone boundaries in samples; divide by the 16 kHz sampling rate to get seconds.\n",
        "TIMIT_SAMPLE_RATE = 16000\n",
        "phones = sample['phonetic_detail']\n",
        "print([(start/TIMIT_SAMPLE_RATE, stop/TIMIT_SAMPLE_RATE, phone)\n",
        "       for start, stop, phone in zip(phones['start'], phones['stop'], phones['utterance'])])"
      ],
      "execution_count": 28,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[(0.0, 0.08, '[SIL]'), (0.08, 0.15, 'W'), (0.15, 0.2, 'UH'), (0.2, 0.24, 'D'), (0.24, 0.38, 'S'), (0.38, 0.46, 'AH'), (0.46, 0.57, 'CH'), (0.57, 0.62, 'AH'), (0.62, 0.68, 'N'), (0.68, 0.83, 'AE'), (0.83, 0.94, 'K'), (0.94, 0.99, 'T'), (0.99, 1.05, 'IH'), (1.05, 1.13, 'V'), (1.13, 1.17, 'R'), (1.17, 1.21, 'IH'), (1.21, 1.22, 'AH'), (1.22, 1.34, 'F'), (1.34, 1.42, 'Y'), (1.42, 1.48, 'UW'), (1.48, 1.53, 'Z'), (1.53, 1.63, 'AH'), (1.63, 1.68, 'L'), (1.68, 1.78, 'B'), (1.78, 1.88, 'IY'), (1.88, 1.99, 'Y'), (1.99, 2.08, 'UW'), (2.08, 2.14, 'S'), (2.14, 2.23, 'F'), (2.23, 2.27, 'AH'), (2.27, 2.48, 'L')]\n",
            "\n",
            " Ground Truth \n",
            "\n",
            "[(0.0, 0.1225, 'h#'), (0.1225, 0.154125, 'w'), (0.154125, 0.2175, 'ix'), (0.2175, 0.25, 'dcl'), (0.25, 0.3725, 's'), (0.3725, 0.4675, 'ah'), (0.4675, 0.4925, 'tcl'), (0.4925, 0.5875, 'ch'), (0.5875, 0.6225, 'ix'), (0.6225, 0.6675, 'n'), (0.6675, 0.8425, 'ae'), (0.8425, 0.98, 'kcl'), (0.98, 0.9925, 't'), (0.9925, 1.0575, 'ix'), (1.0575, 1.1435625, 'v'), (1.1435625, 1.180125, 'r'), (1.180125, 1.2175, 'ix'), (1.2175, 1.3576875, 'f'), (1.3576875, 1.40725, 'y'), (1.40725, 1.5025, 'ux'), (1.5025, 1.574375, 'zh'), (1.574375, 1.6925, 'el'), (1.6925, 1.76, 'bcl'), (1.76, 1.785, 'b'), (1.785, 1.8825, 'iy'), (1.8825, 1.9895, 'y'), (1.9895, 2.0775, 'ux'), (2.0775, 2.165, 's'), (2.165, 2.248, 'f'), (2.248, 2.3575, 'el'), (2.3575, 2.495, 'h#')]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dwTpN_5c2Zm1",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "ac8b79e3-9c5c-439f-be60-2773c730e19f"
      },
      "source": [
        "# Write the predictive aligner's output to 'sample.TextGrid'; serve() prints a confirmation message.\n",
        "# NOTE: this reuses the path written by the earlier serve() cells, replacing their output.\n",
        "charsiu.serve(audio=audio_path, save_to='sample.TextGrid')"
      ],
      "execution_count": 29,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Alignment output has been saved to sample.TextGrid\n"
          ]
        }
      ]
    }
  ]
}